{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "id": "5_Xm1E_kf5FQ",
    "colab_type": "code",
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 251.0
    },
    "outputId": "f9c547c4-7834-451a-a4d0-6b5e7d8c8f0f"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/rnn.py:38: UserWarning: dropout option adds dropout after all but last recurrent layer, so non-zero dropout expects num_layers greater than 1, but got dropout=0.5 and num_layers=1\n",
      "  \"num_layers={}\".format(dropout, num_layers))\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 1000 cost = 0.003078\n",
      "Epoch: 2000 cost = 0.000852\n",
      "Epoch: 3000 cost = 0.000365\n",
      "Epoch: 4000 cost = 0.000183\n",
      "Epoch: 5000 cost = 0.000099\n",
      "test\n",
      "man -> women\n",
      "mans -> women\n",
      "king -> queen\n",
      "black -> white\n",
      "upp -> down\n"
     ]
    }
   ],
   "source": [
    "# code by Tae Hwan Jung(Jeff Jung) @graykode\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch.autograd import Variable\n",
    "\n",
    "dtype = torch.FloatTensor\n",
    "\n",
    "# Special symbols that share the vocabulary with the lowercase letters:\n",
    "#   S - prepended to every decoder input sequence\n",
    "#   E - appended to every decoder target sequence\n",
    "#   P - padding appended to words that are shorter than n_step\n",
    "char_arr = list('SEPabcdefghijklmnopqrstuvwxyz')\n",
    "num_dic = dict(zip(char_arr, range(len(char_arr))))  # character -> class index\n",
    "\n",
    "# Training pairs, each one is [input word, target word]\n",
    "seq_data = [\n",
    "    ['man', 'women'],\n",
    "    ['black', 'white'],\n",
    "    ['king', 'queen'],\n",
    "    ['girl', 'boy'],\n",
    "    ['up', 'down'],\n",
    "    ['high', 'low'],\n",
    "]\n",
    "\n",
    "# Seq2Seq Parameter\n",
    "n_step = 5                  # words are padded with 'P' up to this length\n",
    "n_hidden = 128              # hidden_size of the encoder and decoder RNNs\n",
    "n_class = len(num_dic)      # vocabulary size (one-hot width)\n",
    "batch_size = len(seq_data)  # all training pairs form a single batch\n",
    "\n",
    "def make_batch(seq_data):\n",
    "    \"\"\"Convert [input word, target word] pairs into encoder/decoder tensors.\n",
    "\n",
    "    Each word is right-padded with 'P' up to n_step characters. The padding is\n",
    "    done on local copies, so the caller's seq_data is left unmodified.\n",
    "\n",
    "    Args:\n",
    "        seq_data: sequence of [input word, target word] pairs whose characters\n",
    "            are all keys of num_dic.\n",
    "\n",
    "    Returns:\n",
    "        A tuple (input_batch, output_batch, target_batch):\n",
    "            input_batch: float tensor, one-hot rows of the padded input word.\n",
    "            output_batch: float tensor, one-hot rows of 'S' + padded target word.\n",
    "            target_batch: long tensor, class indices of padded target word + 'E'.\n",
    "\n",
    "    Raises:\n",
    "        KeyError: if a word contains a character that is not in num_dic.\n",
    "    \"\"\"\n",
    "    eye = np.eye(n_class)  # row k is the one-hot vector of class k\n",
    "    input_batch, output_batch, target_batch = [], [], []\n",
    "\n",
    "    for seq in seq_data:\n",
    "        # Pad into local names instead of writing back into seq, so that the\n",
    "        # caller's data is not changed as a side effect.\n",
    "        src = seq[0] + 'P' * (n_step - len(seq[0]))\n",
    "        tgt = seq[1] + 'P' * (n_step - len(seq[1]))\n",
    "\n",
    "        enc_ids = [num_dic[ch] for ch in src]\n",
    "        dec_ids = [num_dic[ch] for ch in 'S' + tgt]\n",
    "        target_ids = [num_dic[ch] for ch in tgt + 'E']\n",
    "\n",
    "        input_batch.append(eye[enc_ids])\n",
    "        output_batch.append(eye[dec_ids])\n",
    "        target_batch.append(target_ids)  # class indices, not one-hot\n",
    "\n",
    "    # Stack into one ndarray before creating the tensor; this avoids PyTorch's\n",
    "    # slow conversion path for a Python list of ndarrays.\n",
    "    return (torch.tensor(np.array(input_batch), dtype=torch.float),\n",
    "            torch.tensor(np.array(output_batch), dtype=torch.float),\n",
    "            torch.tensor(target_batch, dtype=torch.long))\n",
    "\n",
    "# Model\n",
    "class Seq2Seq(nn.Module):\n",
    "    \"\"\"Encoder-decoder made of two single-layer nn.RNN modules and a linear layer.\n",
    "\n",
    "    The final hidden state of the encoder is used as the initial hidden state\n",
    "    of the decoder; the linear layer maps decoder outputs to class scores.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super(Seq2Seq, self).__init__()\n",
    "\n",
    "        # No dropout argument: nn.RNN applies dropout only between stacked\n",
    "        # layers, so on a single-layer RNN it has no effect and only raises a\n",
    "        # UserWarning.\n",
    "        self.enc_cell = nn.RNN(input_size=n_class, hidden_size=n_hidden)\n",
    "        self.dec_cell = nn.RNN(input_size=n_class, hidden_size=n_hidden)\n",
    "        self.fc = nn.Linear(n_hidden, n_class)\n",
    "\n",
    "    def forward(self, enc_input, enc_hidden, dec_input):\n",
    "        \"\"\"Return class scores for every decoder time step.\n",
    "\n",
    "        Args:\n",
    "            enc_input: [batch_size, enc_len, n_class] one-hot encoder input.\n",
    "            enc_hidden: [1, batch_size, n_hidden] initial encoder hidden state.\n",
    "            dec_input: [batch_size, dec_len, n_class] one-hot decoder input.\n",
    "\n",
    "        Returns:\n",
    "            Tensor of shape [dec_len, batch_size, n_class] (time-major).\n",
    "        \"\"\"\n",
    "        enc_input = enc_input.transpose(0, 1)  # -> [enc_len, batch_size, n_class]\n",
    "        dec_input = dec_input.transpose(0, 1)  # -> [dec_len, batch_size, n_class]\n",
    "\n",
    "        # enc_states : [num_layers(=1) * num_directions(=1), batch_size, n_hidden]\n",
    "        _, enc_states = self.enc_cell(enc_input, enc_hidden)\n",
    "        # outputs : [dec_len, batch_size, num_directions(=1) * n_hidden]\n",
    "        outputs, _ = self.dec_cell(dec_input, enc_states)\n",
    "\n",
    "        return self.fc(outputs)  # [dec_len, batch_size, n_class]\n",
    "\n",
    "\n",
    "input_batch, output_batch, target_batch = make_batch(seq_data)\n",
    "\n",
    "model = Seq2Seq()\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=0.001)\n",
    "\n",
    "# Initial encoder hidden state, shape [num_layers * num_directions, batch_size, n_hidden].\n",
    "# It is all zeros and the loop never writes to it, so it is created once here.\n",
    "hidden = torch.zeros(1, batch_size, n_hidden)\n",
    "\n",
    "for epoch in range(5000):\n",
    "    optimizer.zero_grad()\n",
    "    # input_batch  : [batch_size, n_step, n_class]\n",
    "    # output_batch : [batch_size, n_step + 1, n_class] (+1 because of the leading 'S')\n",
    "    # target_batch : [batch_size, n_step + 1] (+1 because of the trailing 'E'), not one-hot\n",
    "    output = model(input_batch, hidden, output_batch)\n",
    "    # output : [n_step + 1, batch_size, n_class]\n",
    "    output = output.transpose(0, 1)  # [batch_size, n_step + 1, n_class]\n",
    "    loss = 0\n",
    "    for i in range(0, len(target_batch)):\n",
    "        # output[i] : [n_step + 1, n_class], target_batch[i] : [n_step + 1]\n",
    "        loss += criterion(output[i], target_batch[i])\n",
    "    if (epoch + 1) % 1000 == 0:\n",
    "        # .item() extracts the Python float from the 0-dim loss tensor for formatting\n",
    "        print('Epoch:', '%04d' % (epoch + 1), 'cost =', '{:.6f}'.format(loss.item()))\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "\n",
    "\n",
    "# Test\n",
    "def translate(word):\n",
    "    \"\"\"Decode the prediction of the module-level model for a single word.\n",
    "\n",
    "    The decoder input is 'S' followed only by 'P' padding, so the prediction\n",
    "    is driven by the encoder state rather than by a known target word.\n",
    "\n",
    "    Args:\n",
    "        word: input string whose characters are all keys of num_dic.\n",
    "\n",
    "    Returns:\n",
    "        The predicted characters up to (not including) the first 'E', with\n",
    "        every 'P' removed. If no 'E' is predicted, the whole prediction is used.\n",
    "    \"\"\"\n",
    "    input_batch, output_batch, _ = make_batch([[word, 'P' * len(word)]])\n",
    "\n",
    "    # make hidden shape [num_layers * num_directions, batch_size(=1), n_hidden]\n",
    "    hidden = torch.zeros(1, 1, n_hidden)\n",
    "    with torch.no_grad():  # inference only, no autograd graph is needed\n",
    "        output = model(input_batch, hidden, output_batch)\n",
    "    # output : [dec_len, batch_size(=1), n_class]\n",
    "\n",
    "    # Best class per time step as a flat list of Python ints.\n",
    "    predict = output.argmax(dim=2).squeeze(1).tolist()\n",
    "    decoded = [char_arr[i] for i in predict]\n",
    "    # Fall back to the full prediction when no 'E' was produced; calling\n",
    "    # list.index unconditionally would raise ValueError in that case.\n",
    "    end = decoded.index('E') if 'E' in decoded else len(decoded)\n",
    "    translated = ''.join(decoded[:end])\n",
    "\n",
    "    return translated.replace('P', '')\n",
    "\n",
    "# Show the model's prediction for a few trained and slightly altered words.\n",
    "print('test')\n",
    "for word in ('man', 'mans', 'king', 'black', 'upp'):\n",
    "    print(word + ' ->', translate(word))"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "name": "Seq2Seq-Torch.ipynb",
   "version": "0.3.2",
   "provenance": [],
   "collapsed_sections": []
  },
  "kernelspec": {
   "name": "python3",
   "display_name": "Python 3"
  },
  "accelerator": "GPU"
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
